[Feat] Adds LongCat-AudioDiT pipeline #13390

Open
RuixiangMa wants to merge 14 commits into huggingface:main from RuixiangMa:longcataudiodit

Conversation


@RuixiangMa RuixiangMa commented Apr 2, 2026

What does this PR do?

Adds LongCat-AudioDiT model support to diffusers.

Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning + iterative latent denoising + VAE decoding) rather than a conventional autoregressive TTS model, so I think it fits naturally into diffusers.
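The three stages named above (text conditioning, iterative latent denoising, VAE decoding) can be sketched as a toy loop. Everything here is an illustrative stand-in with made-up names, not the real LongCat-AudioDiT components:

```python
import random

def encode_text(prompt):
    # Stand-in text encoder: maps a prompt to a fixed-size conditioning vector.
    random.seed(sum(map(ord, prompt)))
    return [random.uniform(-1, 1) for _ in range(4)]

def denoiser(latent, cond, t):
    # Stand-in DiT: nudges the latent toward the conditioning vector.
    return [l + t * (c - l) for l, c in zip(latent, cond)]

def decode_latent(latent):
    # Stand-in VAE decoder: pretends to turn a latent into a waveform.
    return [x * 0.1 for x in latent]

def generate(prompt, num_inference_steps=16):
    cond = encode_text(prompt)                       # 1. text conditioning
    latent = [random.gauss(0, 1) for _ in range(4)]  # start from noise
    for i in range(num_inference_steps):             # 2. iterative denoising
        t = 1.0 - i / num_inference_steps
        latent = denoiser(latent, cond, t)
    return decode_latent(latent)                     # 3. VAE decoding

audio = generate("ocean waves")
```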

Test

```python
import soundfile as sf
import torch
from diffusers import LongCatAudioDiTPipeline

pipeline = LongCatAudioDiTPipeline.from_pretrained(
    "meituan-longcat/LongCat-AudioDiT-1B",
    torch_dtype=torch.float16,
)
pipeline = pipeline.to("cuda")

audio = pipeline(
    prompt="A calm ocean wave ambience with soft wind in the background.",
    audio_end_in_s=5.0,
    num_inference_steps=16,
    guidance_scale=4.0,
    output_type="pt",
).audios

output = audio[0, 0].float().cpu().numpy()
sf.write("longcat.wav", output, pipeline.sample_rate)
```
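For context, the guidance_scale argument above controls classifier-free guidance: the guided prediction extrapolates from the unconditional prediction toward the text-conditioned one. A minimal sketch of that standard combination (whether LongCat-AudioDiT applies it exactly this way is an assumption; the function name is illustrative):

```python
def apply_classifier_free_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    # guided = uncond + scale * (cond - uncond); scale > 1 pushes the sample
    # further toward the text-conditioned prediction.
    return [u + guidance_scale * (c - u) for u, c in zip(noise_pred_uncond, noise_pred_text)]

print(apply_classifier_free_guidance([0.0, 1.0], [1.0, 1.0], 4.0))  # → [4.0, 1.0]
```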

Result

longcat.wav

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa RuixiangMa changed the title Longcataudiodit [Feat] Adds LongCat-AudioDiT support Apr 2, 2026
@RuixiangMa RuixiangMa changed the title [Feat] Adds LongCat-AudioDiT support [Feat] Adds LongCat-AudioDiT pipeline Apr 2, 2026
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@dg845 dg845 requested review from dg845 and yiyixuxu April 4, 2026 00:31
```python
)


def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
```
Collaborator

Similarly, I think we should inline _pixel_shuffle_1d in UpsampleShortcut following #13390 (comment).
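As background, a 1D pixel shuffle trades channels for temporal resolution: a (batch, channels·r, time) input becomes (batch, channels, time·r) by interleaving channel groups along the time axis. A pure-Python index-level sketch of the standard layout (the real helper operates on torch tensors, and the exact channel ordering in LongCat is an assumption here):

```python
def pixel_shuffle_1d(x, factor):
    # x: nested lists with shape (batch, channels, time); channels % factor == 0.
    batch, channels, time = len(x), len(x[0]), len(x[0][0])
    out = [
        [[0.0] * (time * factor) for _ in range(channels // factor)]
        for _ in range(batch)
    ]
    for b in range(batch):
        for c in range(channels):
            for t in range(time):
                # The channel index splits into (output channel, phase); the
                # phase becomes the sub-position along the upsampled time axis.
                out[b][c // factor][t * factor + c % factor] = x[b][c][t]
    return out

print(pixel_shuffle_1d([[[1, 2], [3, 4]]], 2))  # → [[[1, 3, 2, 4]]]
```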

Comment on lines +515 to +519
```python
        self.time_embed = AudioDiTTimestepEmbedding(dim)
        self.input_embed = AudioDiTEmbedder(latent_dim, dim)
        self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
        self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
        self.blocks = nn.ModuleList(
```
Collaborator

Suggested change

```python
        self.time_embed = AudioDiTTimestepEmbedding(dim)
        self.input_embed = AudioDiTEmbedder(latent_dim, dim)
        self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
        self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
        self.blocks = nn.ModuleList(
```

See #13390 (comment).

Comment on lines +584 to +589
```python
        batch_size = hidden_states.shape[0]
        if timestep.ndim == 0:
            timestep = timestep.repeat(batch_size)
        timestep_embed = self.time_embed(timestep)
        text_mask = encoder_attention_mask.bool()
        encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
```
Collaborator

Suggested change

```python
        batch_size = hidden_states.shape[0]
        if timestep.ndim == 0:
            timestep = timestep.repeat(batch_size)
        timestep_embed = self.time_embed(timestep)
        text_mask = encoder_attention_mask.bool()
        encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
```

Can you also refactor forward here so that it is better organized, following #13390 (comment)? See, for example, the QwenImageTransformer2DModel.forward method.
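The timestep handling in the quoted lines follows a common diffusers pattern: a scalar (0-dim) timestep is expanded so every batch element shares it, while an already-batched timestep passes through. A list-based sketch of the same logic (the function name is illustrative, not part of the PR):

```python
def expand_timestep(timestep, batch_size):
    # Mirrors `timestep.repeat(batch_size)` for a 0-dim tensor: a single
    # scalar timestep is shared across the whole batch; an already-batched
    # timestep passes through unchanged.
    if isinstance(timestep, (int, float)):
        return [timestep] * batch_size
    return list(timestep)

print(expand_timestep(0.5, 3))         # → [0.5, 0.5, 0.5]
print(expand_timestep([0.1, 0.9], 2))  # → [0.1, 0.9]
```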

Author

Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.

Comment on lines +88 to +98
```python
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
    def test_layerwise_casting_memory(self):
        pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.")

    def test_layerwise_casting_training(self):
        pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.")

    def test_group_offloading_with_layerwise_casting(self, *args, **kwargs):
        pytest.skip(
            "LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet."
        )
```
Collaborator

Suggested change

```diff
 class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
-    def test_layerwise_casting_memory(self):
-        pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.")
-
-    def test_layerwise_casting_training(self):
-        pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.")
-
-    def test_group_offloading_with_layerwise_casting(self, *args, **kwargs):
-        pytest.skip(
-            "LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet."
-        )
+    pass
```

Layerwise casting should work if #13390 (comment) is applied.

Author

I removed the layerwise casting training and combined group-offloading/layerwise-casting skips after updating the dtype handling. I kept test_layerwise_casting_memory skipped because the tiny transformer config does not provide stable peak-memory behavior for that assertion.

Collaborator

@dg845 dg845 left a comment

Thanks for your continued work on this! Left some suggestions that should help LongCatAudioDiTPipeline support model offloading, layerwise casting, etc.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 11, 2026
@dg845
Collaborator

dg845 commented Apr 14, 2026

@bot /style

@github-actions
Contributor

github-actions bot commented Apr 14, 2026

Style bot fixed some files and pushed the changes.

@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026

@classmethod
@validate_hf_hub_args
def from_pretrained(
Collaborator

Can you add a conversion script? Our pipeline should not define a from_pretrained method.

Author

> can you add a conversion script? our pipeline should not define from_pretrained method

Added it and tested.
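For readers unfamiliar with diffusers conversion scripts: they typically load the original checkpoint and remap parameter names to the diffusers module layout. The rename rules below are invented placeholders to show the shape of such a script, not the actual LongCat-AudioDiT key mapping:

```python
# Hypothetical key-remapping core of a checkpoint conversion script; the
# prefixes below are placeholders, not the real LongCat-AudioDiT layout.
RENAME_RULES = [
    ("time_embedding.", "time_embed."),
    ("input_proj.", "input_embed."),
]

def convert_state_dict(state_dict):
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in RENAME_RULES:
            if new_key.startswith(old_prefix):
                new_key = new_prefix + new_key[len(old_prefix):]
        converted[new_key] = value
    return converted

print(convert_state_dict({"time_embedding.linear.weight": 0}))
# → {'time_embed.linear.weight': 0}
```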

@yiyixuxu
Collaborator

@claude can you help with a review here?

@github-actions
Contributor

Claude Code is working…

I'll analyze this and get back to you.


```python
        timesteps = self.scheduler.timesteps
        self._num_timesteps = len(timesteps)

        for i, t in enumerate(timesteps):
```
Collaborator

Can you add support for a progress bar here? For example, here is how Flux 2 implements a progress bar with self.progress_bar:

```python
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
```

This will make it easier to track progress during inference.
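In diffusers, self.progress_bar is provided by DiffusionPipeline and wraps tqdm. A dependency-free stand-in showing the same context-manager shape (illustrative only, not the diffusers implementation):

```python
from contextlib import contextmanager

@contextmanager
def progress_bar(total):
    # Minimal stand-in for DiffusionPipeline.progress_bar (which wraps tqdm):
    # tracks completed steps and reports a summary when the loop exits.
    class Bar:
        def __init__(self, total):
            self.total, self.n = total, 0

        def update(self, n=1):
            self.n += n

    bar = Bar(total)
    try:
        yield bar
    finally:
        print(f"{bar.n}/{bar.total} steps completed")

timesteps = list(range(16))
with progress_bar(total=len(timesteps)) as bar:
    for i, t in enumerate(timesteps):
        # one scheduler/denoiser step would run here in the real pipeline
        bar.update()
# prints "16/16 steps completed"
```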

Author

> Can you add support for a progress bar here? For example, here is how Flux 2 implements a progress bar with self.progress_bar:
>
> ```python
> with self.progress_bar(total=num_inference_steps) as progress_bar:
>     for i, t in enumerate(timesteps):
> ```
>
> This will make it easier to track progress during inference.

Done

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:03<00:00, 13.68it/s]

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@github-actions github-actions bot added size/L PR with diff > 200 LOC and removed size/L PR with diff > 200 LOC labels Apr 14, 2026

Labels

documentation (Improvements or additions to documentation), models, pipelines, size/L (PR with diff > 200 LOC), tests, utils


4 participants